#!/usr/bin/env python
"""
parse a MAVLink protocol XML file and generate a python implementation

Copyright Andrew Tridgell 2011
Released under GNU GPL version 3 or later
"""
from __future__ import print_function

from builtins import range

import os
import sys
import textwrap
from . import mavtemplate

# Module-wide template engine. Every generate_* function below renders its
# ${...} code templates through t.write(outf, template, params).
t = mavtemplate.MAVTemplate()


def extend_with_type_info(extended, enable_type_annotations):
    """
    Populate a template-variable dict with the typing placeholders used by the
    generated code, and return that same dict.

    Every row of ``type_table`` (short name, annotation text, default value)
    produces these keys:

      type_<name>, type_<name>_ret, type_<name>_cast,
      type_<name>_default            (only when the default is not None),
      type_optional_<name>, type_optional_<name>_ret,
      type_optional_<name>_cast, type_optional_<name>_default

    plus the single key ``typing_imports``.

    :param extended: dict to update; it is modified in place and returned
    :param enable_type_annotations: when true the placeholders expand to real
        annotations (": int", " -> int", "int", ": int = 0") and
        ``typing_imports`` is an import line from ``typing``. When false the
        annotation placeholders expand to "" (defaults to "=<repr>"), cast
        placeholders expand to quoted type names, and ``typing_imports``
        defines a no-op ``cast()`` for Python 2.
    :return: ``extended`` itself
    """
    type_table = (
        ("int", "int", 0),
        ("bool", "bool", False),
        ("float", "float", 0),
        ("str", "str", ""),
        ("bytes", "bytes", b""),
        ("bytearray", "bytearray", bytearray(b"")),
        ("none", "None", None),
        ("object", "object", None),
        ("str_list", "List[str]", None),
        ("int_list", "List[int]", None),
        ("str_float_int", "Union[str, float, int]", None),
        ("any", "Any", None),
        ("mavlink", '"MAVLink"', None),
        ("mavlink_header", "MAVLink_header", None),
        ("mavlink_message", "MAVLink_message", None),
        ("mavlink_message_type", "Type[MAVLink_message]", None),
        ("mavlink_message_attr", "Union[bytes, float, int]", None),
        (
            "mavlink_message_assign_attr_list",
            "List[Union[bytes, float, int, Sequence[float], Sequence[int]]]",
            None,
        ),
        ("mavlink_message_list", "List[MAVLink_message]", None),
        ("mavlink_message_signed_callback", 'Callable[["MAVLink", int], bool]', None),
        ("dict_str_to_str_float_int", "Dict[str, Union[str, float, int]]", None),
        ("dict_str_to_dict_int_to_enumentry", "Dict[str, Dict[int, EnumEntry]]", None),
        ("dict_int_to_str", "Dict[int, str]", None),
        ("dict_str_to_str", "Dict[str, str]", None),
        ("dict_int_int_int_to_int", "Dict[Tuple[int, int, int], int]", None),
        ("dict_int_to_mavlink_message_type", "Dict[int, Type[MAVLink_message]]", None),
        ("tuple_int", "Tuple[int]", None),
        ("tuple_int_int", "Tuple[int, int]", None),
        ("tuple_int_int_int", "Tuple[int, int, int]", None),
        ("tuple_bytes_five_int", "Tuple[bytes, int, int, int, int, int]", None),
        ("tuple_bytes_eight_int", "Tuple[bytes, int, int, int, int, int, int, int, int]", None),
        ("tuple_bytes_int_float_repeat", "Tuple[Union[bytes, int, float], ...]", None),
        ("intseq", "Sequence[int]", None),
        ("intseq_floatseq", "Union[Sequence[int], Sequence[float]]", None),
        ("args", "Iterable[Any]", None),
        ("kwargs", "Mapping[str, Any]", None),
        ("generic_callback", "Callable[..., None]", None),
    )

    if enable_type_annotations:
        extended["typing_imports"] = (
            "from typing import Any, Callable, Dict, Iterable, List, Mapping,"
            " Optional, Sequence, Tuple, Type, Union, cast"
        )
    else:
        # Without typing support the generated code still calls cast(), so
        # emit a do-nothing stand-in for it.
        extended["typing_imports"] = (
            "\n"
            "\n"
            "def cast(type_str, arg):\n"
            '    """\n'
            "    No-op for Python2 used instead of typing.cast()\n"
            '    """\n'
            "    return arg\n"
        )

    for short_name, annotation, default in type_table:
        plain_key = "type_" + short_name
        opt_key = "type_optional_" + short_name
        opt_annotation = "Optional[" + annotation + "]"
        if enable_type_annotations:
            extended[plain_key] = ": " + annotation
            extended[plain_key + "_ret"] = " -> " + annotation
            extended[plain_key + "_cast"] = annotation
            if default is not None:
                extended[plain_key + "_default"] = ": %s = %r" % (annotation, default)
            extended[opt_key] = ": " + opt_annotation
            extended[opt_key + "_ret"] = " -> " + opt_annotation
            extended[opt_key + "_cast"] = opt_annotation
            extended[opt_key + "_default"] = ": " + opt_annotation + " = None"
        else:
            extended[plain_key] = ""
            extended[plain_key + "_ret"] = ""
            extended[plain_key + "_cast"] = '"%s"' % annotation
            if default is not None:
                extended[plain_key + "_default"] = "=%r" % (default,)
            extended[opt_key] = ""
            extended[opt_key + "_ret"] = ""
            extended[opt_key + "_cast"] = '"%s"' % opt_annotation
            extended[opt_key + "_default"] = "=None"

    return extended


def generate_preamble(outf, msgs, basename, args, xml):
    """
    Write the fixed head of the generated module to ``outf``.

    This covers the module docstring, imports, wire-protocol constants, the
    ``x25crc`` checksum class, ``MAVLink_header``, the ``MAVLink_message``
    base class and the ``mavlink_msg_deprecated_name_property`` descriptor.

    :param outf: writable text stream the generated code is written to
    :param msgs: not used by this function; kept so the signature is stable
    :param basename: output path; its file stem becomes ``DIALECT``
    :param args: sequence of XML file names, listed comma-separated in the
        generated module docstring
    :param xml: mapping (or iterable of key/value pairs) of template
        variables. The template reads at least ``wire_protocol_version``,
        ``protocol_marker``, ``crc_extra``, ``typing_imports`` and the
        ``type_*`` keys. Presumably built with ``extend_with_type_info``;
        verify against the caller.
    """
    print("Generating preamble")

    # Copy so the caller's mapping is not modified by the keys added below.
    params = dict(xml)
    # Must be a plain string: the former 1-tuple (stray trailing comma) would
    # presumably be substituted as its repr, e.g. "('a.xml,b.xml',)".
    params["FILELIST"] = ",".join(args)
    params["DIALECT"] = os.path.splitext(os.path.basename(basename))[0]

    t.write(
        outf,
        '''
"""
MAVLink protocol implementation (auto-generated by mavgen.py)

Generated from: ${FILELIST}

Note: this file has been auto-generated. DO NOT EDIT
"""
import hashlib
import json
import logging
import os
import struct
import sys
import time
from builtins import object, range
${typing_imports}

WIRE_PROTOCOL_VERSION = "${wire_protocol_version}"
DIALECT = "${DIALECT}"

PROTOCOL_MARKER_V1 = 0xFE
PROTOCOL_MARKER_V2 = 0xFD
HEADER_LEN_V1 = 6
HEADER_LEN_V2 = 10

MAVLINK_SIGNATURE_BLOCK_LEN = 13

MAVLINK_IFLAG_SIGNED = 0x01

if sys.version_info[0] == 2:
    logging.basicConfig()

logger = logging.getLogger(__name__)

# allow MAV_IGNORE_CRC=1 to ignore CRC, allowing some
# corrupted msgs to be seen
MAVLINK_IGNORE_CRC = os.environ.get("MAV_IGNORE_CRC", 0)

# some base types from mavlink_types.h
MAVLINK_TYPE_CHAR = 0
MAVLINK_TYPE_UINT8_T = 1
MAVLINK_TYPE_INT8_T = 2
MAVLINK_TYPE_UINT16_T = 3
MAVLINK_TYPE_INT16_T = 4
MAVLINK_TYPE_UINT32_T = 5
MAVLINK_TYPE_INT32_T = 6
MAVLINK_TYPE_UINT64_T = 7
MAVLINK_TYPE_INT64_T = 8
MAVLINK_TYPE_FLOAT = 9
MAVLINK_TYPE_DOUBLE = 10


class x25crc(object):
    """CRC-16/MCRF4XX - based on checksum.h from mavlink library"""

    def __init__(self, buf${type_optional_intseq_default})${type_none_ret}:
        self.crc = 0xFFFF
        if buf is not None:
            self.accumulate(buf)

    def accumulate(self, buf${type_intseq})${type_none_ret}:
        """add in some more bytes (it also accepts python2 strings)"""
        if sys.version_info[0] == 2 and type(buf) is str:
            buf = bytearray(buf)

        accum = self.crc
        for b in buf:
            tmp = b ^ (accum & 0xFF)
            tmp = (tmp ^ (tmp << 4)) & 0xFF
            accum = (accum >> 8) ^ (tmp << 8) ^ (tmp << 3) ^ (tmp >> 4)
        self.crc = accum


class MAVLink_header(object):
    """MAVLink message header"""

    def __init__(self, msgId${type_int}, incompat_flags${type_int_default}, compat_flags${type_int_default}, mlen${type_int_default}, seq${type_int_default}, srcSystem${type_int_default}, srcComponent${type_int_default})${type_none_ret}:
        self.mlen = mlen
        self.seq = seq
        self.srcSystem = srcSystem
        self.srcComponent = srcComponent
        self.msgId = msgId
        self.incompat_flags = incompat_flags
        self.compat_flags = compat_flags

    def pack(self, force_mavlink1${type_bool_default})${type_bytes_ret}:
        if float(WIRE_PROTOCOL_VERSION) == 2.0 and not force_mavlink1:
            return struct.pack(
                "<BBBBBBBHB",
                ${protocol_marker},
                self.mlen,
                self.incompat_flags,
                self.compat_flags,
                self.seq,
                self.srcSystem,
                self.srcComponent,
                self.msgId & 0xFFFF,
                self.msgId >> 16,
            )
        return struct.pack(
            "<BBBBBB",
            PROTOCOL_MARKER_V1,
            self.mlen,
            self.seq,
            self.srcSystem,
            self.srcComponent,
            self.msgId,
        )


class MAVLink_message(object):
    """base MAVLink message class"""

    id = 0
    msgname = ""
    fieldnames${type_str_list} = []
    ordered_fieldnames${type_str_list} = []
    fieldtypes${type_str_list} = []
    fielddisplays_by_name${type_dict_str_to_str} = {}
    fieldenums_by_name${type_dict_str_to_str} = {}
    fieldunits_by_name${type_dict_str_to_str} = {}
    native_format = bytearray(b"")
    orders${type_int_list} = []
    lengths${type_int_list} = []
    array_lengths${type_int_list} = []
    crc_extra = 0
    unpacker = struct.Struct("")
    instance_field${type_optional_str} = None
    instance_offset = -1

    def __init__(self, msgId${type_int}, name${type_str})${type_none_ret}:
        self._header = MAVLink_header(msgId)
        self._payload${type_optional_bytes} = None
        self._msgbuf = bytearray(b"")
        self._crc${type_optional_int} = None
        self._fieldnames${type_str_list} = []
        self._type = name
        self._signed = False
        self._link_id${type_optional_int} = None
        self._instances${type_optional_dict_str_to_str} = None
        self._instance_field${type_optional_str} = None

    def format_attr(self, field${type_str})${type_str_float_int_ret}:
        """override field getter"""
        raw_attr = cast(${type_mavlink_message_attr_cast}, getattr(self, field))
        if isinstance(raw_attr, bytes):
            if sys.version_info[0] == 2:
                return raw_attr.rstrip(b"\\x00")
            return raw_attr.decode(errors="backslashreplace").rstrip("\\x00")
        return raw_attr

    def get_msgbuf(self)${type_bytearray_ret}:
        return self._msgbuf

    def get_header(self)${type_mavlink_header_ret}:
        return self._header

    def get_payload(self)${type_optional_bytes_ret}:
        return self._payload

    def get_crc(self)${type_optional_int_ret}:
        return self._crc

    def get_fieldnames(self)${type_str_list_ret}:
        return self._fieldnames

    def get_type(self)${type_str_ret}:
        return self._type

    def get_msgId(self)${type_int_ret}:
        return self._header.msgId

    def get_srcSystem(self)${type_int_ret}:
        return self._header.srcSystem

    def get_srcComponent(self)${type_int_ret}:
        return self._header.srcComponent

    def get_seq(self)${type_int_ret}:
        return self._header.seq

    def get_signed(self)${type_bool_ret}:
        return self._signed

    def get_link_id(self)${type_optional_int_ret}:
        return self._link_id

    def __str__(self)${type_str_ret}:
        ret = "%s {" % self._type
        for a in self._fieldnames:
            v = self.format_attr(a)
            ret += "%s : %s, " % (a, v)
        ret = ret[0:-2] + "}"
        return ret

    def __ne__(self, other${type_object})${type_bool_ret}:
        return not self.__eq__(other)

    def __eq__(self, other${type_object})${type_bool_ret}:
        if other is None:
            return False

        if not isinstance(other, MAVLink_message):
            return False

        if self.get_type() != other.get_type():
            return False

        if self.get_crc() != other.get_crc():
            return False

        if self.get_seq() != other.get_seq():
            return False

        if self.get_srcSystem() != other.get_srcSystem():
            return False

        if self.get_srcComponent() != other.get_srcComponent():
            return False

        for a in self._fieldnames:
            if self.format_attr(a) != other.format_attr(a):
                return False

        return True

    def to_dict(self)${type_dict_str_to_str_float_int_ret}:
        d${type_dict_str_to_str_float_int} = {}
        d["mavpackettype"] = self._type
        for a in self._fieldnames:
            d[a] = self.format_attr(a)
        return d

    def to_json(self)${type_str_ret}:
        return json.dumps(self.to_dict())

    def sign_packet(self, mav${type_mavlink})${type_none_ret}:
        assert mav.signing.secret_key is not None

        h = hashlib.new("sha256")
        self._msgbuf += struct.pack("<BQ", mav.signing.link_id, mav.signing.timestamp)[:7]
        h.update(mav.signing.secret_key)
        h.update(self._msgbuf)
        sig = h.digest()[:6]
        self._msgbuf += sig
        mav.signing.timestamp += 1

    def _pack(self, mav${type_mavlink}, crc_extra${type_int}, payload${type_bytes}, force_mavlink1${type_bool_default})${type_bytes_ret}:
        plen = len(payload)
        if float(WIRE_PROTOCOL_VERSION) == 2.0 and not force_mavlink1:
            # in MAVLink2 we can strip trailing zeros off payloads. This allows for simple
            # variable length arrays and smaller packets
            if sys.version_info[0] == 2:
                nullbyte = chr(0)
            else:
                nullbyte = 0
            while plen > 1 and payload[plen - 1] == nullbyte:
                plen -= 1
        self._payload = payload[:plen]
        incompat_flags = 0
        if mav.signing.sign_outgoing:
            incompat_flags |= MAVLINK_IFLAG_SIGNED
        self._header = MAVLink_header(
            self._header.msgId,
            incompat_flags=incompat_flags,
            compat_flags=0,
            mlen=len(self._payload),
            seq=mav.seq,
            srcSystem=mav.srcSystem,
            srcComponent=mav.srcComponent,
        )
        self._msgbuf = bytearray(self._header.pack(force_mavlink1=force_mavlink1))
        self._msgbuf += self._payload
        crc = x25crc(self._msgbuf[1:])
        if ${crc_extra}:
            # we are using CRC extra
            crc.accumulate(struct.pack("B", crc_extra))
        self._crc = crc.crc
        self._msgbuf += struct.pack("<H", self._crc)
        if mav.signing.sign_outgoing and not force_mavlink1:
            self.sign_packet(mav)
        return bytes(self._msgbuf)

    def pack(self, mav${type_mavlink}, force_mavlink1${type_bool_default})${type_bytes_ret}:
        raise NotImplementedError("MAVLink_message cannot be serialized directly")

    def __getitem__(self, key${type_str})${type_str_ret}:
        """support indexing, allowing for multi-instance sensors in one message"""
        if self._instances is None:
            raise IndexError()
        if key not in self._instances:
            raise IndexError()
        return self._instances[key]


class mavlink_msg_deprecated_name_property(object):
    """
    This handles the class variable name change from name to msgname for
    subclasses of MAVLink_message during a transition period.

    This is used by setting the class variable to
    `mavlink_msg_deprecated_name_property()`.
    """

    def __get__(self, instance${type_optional_mavlink_message}, owner${type_mavlink_message_type})${type_str_ret}:
        if instance is not None:
            logger.error("Using .name on a MAVLink_message is not supported, use .get_type() instead.")
            raise AttributeError("Class {} has no attribute 'name'".format(owner.__name__))
        logger.warning(
            """Using .name on a MAVLink_message class is deprecated, consider using .msgname instead.
Note that if compatibility with pymavlink 2.4.30 and earlier is desired, use something like this:

msg_name =  msg.msgname if hasattr(msg, "msgname") else msg.name"""
        )
        return owner.msgname

''',
        params,
    )


def generate_enums(outf, enums, enable_type_annotations):
    print("Generating enums")

    type_info = extend_with_type_info({}, enable_type_annotations)

    t.write(
        outf,
        """

# enums


class EnumEntry(object):
    def __init__(self, name${type_str}, description${type_str})${type_none_ret}:
        self.name = name
        self.description = description
        self.param${type_dict_int_to_str} = {}
        self.has_location = False


enums${type_dict_str_to_dict_int_to_enumentry} = {}
""",
        type_info,
    )

    for e in enums:
        outf.write("\n# %s\n" % e.name)
        outf.write('enums["%s"] = {}\n' % e.name)
        for entry in e.entry:
            outf.write("%s = %u\n" % (entry.name, entry.value))
            description = entry.description.replace("\t", "    ")
            if "\n" in description:
                outf.write(
                    'enums["%s"][%d] = EnumEntry(\n    "%s",\n    """%s""",\n)\n'
                    % (e.name, int(entry.value), entry.name, description)
                )
            else:
                outf.write(
                    'enums["%s"][%d] = EnumEntry("%s", """%s""")\n'
                    % (e.name, int(entry.value), entry.name, description)
                )
            if entry.has_location:
                outf.write('enums["%s"][%d].has_location = True\n' %
                           (e.name, int(entry.value),))
            for param in entry.param:
                description = param.description.replace("\t", "    ")
                if "\n" in description:
                    outf.write(
                        'enums["%s"][%d].param[\n    %d\n] = """%s"""\n'
                        % (e.name, int(entry.value), int(param.index), description)
                    )
                else:
                    outf.write(
                        'enums["%s"][%d].param[%d] = """%s"""\n'
                        % (e.name, int(entry.value), int(param.index), description)
                    )


def generate_message_ids(outf, msgs):
    """
    Emit the ``MAVLINK_MSG_ID_*`` constants: the two pseudo IDs for bad data
    (-1) and unknown messages (-2), then one constant per message in ``msgs``
    named after the upper-cased message name.
    """
    print("Generating message IDs")
    fixed_lines = (
        "\n# message IDs\n",
        "MAVLINK_MSG_ID_BAD_DATA = -1\n",
        "MAVLINK_MSG_ID_UNKNOWN = -2\n",
    )
    for line in fixed_lines:
        outf.write(line)
    for msg in msgs:
        outf.write("MAVLINK_MSG_ID_%s = %u\n" % (msg.name.upper(), msg.id))


def byname_hash_from_field_attribute(m, attribute):
    """
    Render the body of a dict literal that maps each field of message ``m``
    to that field's ``attribute`` value, e.g. '"alt": "m", "vx": "cm/s"'.

    Fields whose attribute is missing, None or "" are skipped, and the result
    is "" when no field qualifies. For ``attribute == "units"``, a value
    starting with "[" loses its first and last characters (units appear to be
    stored bracketed, e.g. "[m/s]").
    """
    def _present(value):
        return value is not None and value != ""

    def _unbracket(value):
        # hack; remove the square brackets further up
        if attribute == "units" and value[0] == "[":
            return value[1:-1]
        return value

    candidates = ((fld, getattr(fld, attribute, None)) for fld in m.fields)
    return ", ".join(
        '"%s": "%s"' % (fld.name, _unbracket(value))
        for fld, value in candidates
        if _present(value)
    )


def generate_classes(outf, msgs, enable_type_annotations):
    """
    Write one ``MAVLink_<name>_message`` subclass of ``MAVLink_message`` per
    message to ``outf``.

    :param outf: writable text stream the generated code is written to
    :param msgs: parsed messages. The attributes read here are ``name``,
        ``description``, ``fields``, ``fieldnames``, ``ordered_fields``,
        ``ordered_fieldnames``, ``fieldtypes``, ``fielddefaults``,
        ``extensions_start``, ``instance_field``, ``field_offsets``,
        ``fmtstr``, ``native_fmtstr``, ``order_map``, ``len_map``,
        ``array_len_map`` and ``crc_extra``.
    :param enable_type_annotations: emit annotations in the generated
        ``__init__``/``pack`` signatures and class attributes
    """
    print("Generating class definitions")
    wrapper = textwrap.TextWrapper(initial_indent="    ", subsequent_indent="    ")
    for m in msgs:
        classname = "MAVLink_%s_message" % m.name.lower()
        fieldname_str = ", ".join(['"%s"' % s for s in m.fieldnames])
        ordered_fieldname_str = ", ".join(['"%s"' % s for s in m.ordered_fieldnames])
        fielddisplays_str = byname_hash_from_field_attribute(m, "display")
        fieldenums_str = byname_hash_from_field_attribute(m, "enum")
        fieldunits_str = byname_hash_from_field_attribute(m, "units")

        fieldtypes_str = ", ".join(['"%s"' % s for s in m.fieldtypes])
        if m.instance_field is not None:
            instance_field = '"%s"' % m.instance_field
            instance_offset = m.field_offsets[m.instance_field]
        else:
            instance_field = "None"
            instance_offset = -1

        # The XML description becomes the generated class docstring. Escape it
        # so that literal's value is the original text: double backslashes
        # (otherwise they form escape sequences, which recent Pythons warn
        # about as invalid) and break up any embedded triple double-quote.
        docstring = wrapper.fill(m.description.strip())
        docstring = docstring.replace("\\", "\\\\").replace('"""', '\\"\\"\\"')

        # __init__ parameter list. MAVLink2 extension fields (index >=
        # extensions_start) are optional and take their precomputed default.
        arg_fields = []
        for i, field in enumerate(m.fields):
            fname = m.fieldnames[i]
            arg = fname
            if enable_type_annotations:
                arg = "%s: %s" % (fname, mavpytype(field))
            if m.extensions_start is not None and i >= m.extensions_start:
                separator = " = " if enable_type_annotations else "="
                arg = "%s%s%s" % (arg, separator, m.fielddefaults[i])
            arg_fields.append(arg)

        # char fields keep the raw bytes (used by pack()) and also expose a
        # decoded str truncated at the first NUL.
        init_fields = []
        for f in m.fields:
            if f.type == "char":
                init_fields.append("self._%s_raw = %s" % (f.name, f.name))
                init_fields.append('self.%s = %s.split(b"\\x00", 1)[0].decode("ascii", errors="replace")' % (f.name, f.name))
            else:
                init_fields.append("self.%s = %s" % (f.name, f.name))

        # Arguments for unpacker.pack(), in wire order; numeric arrays are
        # passed element by element.
        pack_fields = []
        for field in m.ordered_fields:
            if field.type == "char":
                pack_fields.append("self._{0:s}_raw".format(field.name))
            elif field.array_length == 0:
                pack_fields.append("self.{0:s}".format(field.name))
            else:
                for i in range(field.array_length):
                    pack_fields.append("self.{0:s}[{1:d}]".format(field.name, i))

        t.write(
            outf,
            '''


class ${classname}(MAVLink_message):
    """
${docstring}
    """

    id = MAVLINK_MSG_ID_${msg_name_upper}
    msgname = "${msg_name_upper}"
    fieldnames = [${field_names}]
    ordered_fieldnames = [${ordered_field_names}]
    fieldtypes = [${field_types}]
    fielddisplays_by_name${type_dict_str_to_str} = {${field_displays}}
    fieldenums_by_name${type_dict_str_to_str} = {${field_nums}}
    fieldunits_by_name${type_dict_str_to_str} = {${field_units}}
    native_format = bytearray(b"${native_fmtstr}")
    orders = ${orders}
    lengths = ${lengths}
    array_lengths = ${array_lengths}
    crc_extra = ${crc_extra}
    unpacker = struct.Struct("${fmtstr}")
    instance_field = ${instance_field}
    instance_offset = ${instance_offset}

    def __init__(self, ${arg_fields}):
        MAVLink_message.__init__(self, ${classname}.id, ${classname}.msgname)
        self._fieldnames = ${classname}.fieldnames
        self._instance_field = ${classname}.instance_field
        self._instance_offset = ${classname}.instance_offset
        ${init_fields}

    def pack(self, mav${type_mavlink}, force_mavlink1${type_bool_default})${type_bytes_ret}:
        return self._pack(mav, self.crc_extra, self.unpacker.pack(${pack_fields}), force_mavlink1=force_mavlink1)


# Define name on the class for backwards compatibility (it is now msgname).
# Done with setattr to hide the class variable from mypy.
setattr(${classname}, "name", mavlink_msg_deprecated_name_property())
''',
            extend_with_type_info(
                {
                    "classname": classname,
                    "docstring": docstring,
                    "msg_name_upper": m.name.upper(),
                    "field_names": fieldname_str,
                    "ordered_field_names": ordered_fieldname_str,
                    "field_types": fieldtypes_str,
                    "field_displays": fielddisplays_str,
                    "field_nums": fieldenums_str,
                    "field_units": fieldunits_str,
                    "fmtstr": m.fmtstr,
                    "native_fmtstr": m.native_fmtstr,
                    "orders": m.order_map,
                    "lengths": m.len_map,
                    "array_lengths": m.array_len_map,
                    "crc_extra": m.crc_extra,
                    "instance_field": instance_field,
                    "instance_offset": instance_offset,
                    "arg_fields": ", ".join(arg_fields),
                    "init_fields": "\n        ".join(init_fields),
                    "pack_fields": ", ".join(pack_fields),
                },
                enable_type_annotations,
            ),
        )


def native_mavfmt(field):
    """
    Return the one-character format code for ``field.type``, ignoring any
    array length.

    Same codes as struct for the numeric and char types, except that
    ``uint8_t_mavlink_version`` maps to "v", which is not a struct code.
    Presumably used to build a message's native format string; verify
    against the caller. Raises KeyError for an unknown type.
    """
    type_codes = dict(
        float="f",
        double="d",
        char="c",
        int8_t="b",
        uint8_t="B",
        uint8_t_mavlink_version="v",
        int16_t="h",
        uint16_t="H",
        int32_t="i",
        uint32_t="I",
        int64_t="q",
        uint64_t="Q",
    )
    code = type_codes[field.type]
    return code


def mavfmt(field):
    """
    Return the struct format fragment for one field.

    A scalar maps to its single struct code. An array of n elements becomes
    "<n><code>", except that char arrays use "<n>s" so they pack and unpack
    as one bytes object. Raises KeyError for an unknown non-char type.
    """
    struct_codes = {
        "float": "f",
        "double": "d",
        "char": "c",
        "int8_t": "b",
        "uint8_t": "B",
        "uint8_t_mavlink_version": "B",
        "int16_t": "h",
        "uint16_t": "H",
        "int32_t": "i",
        "uint32_t": "I",
        "int64_t": "q",
        "uint64_t": "Q",
    }

    count = field.array_length
    if not count:
        return struct_codes[field.type]
    element = "s" if field.type == "char" else struct_codes[field.type]
    return "{}{}".format(count, element)


def mavpytype(field):
    """
    Return the Python annotation text for a field, as used in the generated
    ``__init__`` signature.

    Floating types map to "float", integer types to "int" and char to
    "bytes". Numeric arrays become "Sequence[<scalar>]", while char arrays
    stay "bytes". Raises KeyError for an unknown type.
    """
    py_names = dict.fromkeys(("float", "double"), "float")
    py_names["char"] = "bytes"
    py_names.update(
        dict.fromkeys(
            (
                "int8_t",
                "uint8_t",
                "uint8_t_mavlink_version",
                "int16_t",
                "uint16_t",
                "int32_t",
                "uint32_t",
                "int64_t",
                "uint64_t",
            ),
            "int",
        )
    )

    scalar = py_names[field.type]
    is_sequence = bool(field.array_length) and field.type != "char"
    return "Sequence[{}]".format(scalar) if is_sequence else scalar


def mavdefault(field):
    """
    Return Python source text for a field's default value in MAVLink2
    extensions (presumably the source of ``fielddefaults``; verify against
    the caller).

    * char array           -> 'b""'
    * single (non-array) char -> 'b"\\x00"'. Its struct code is "c", which
      needs exactly one byte, so an empty bytes default could never be packed.
    * numeric scalar       -> "0"
    * numeric array of n   -> a tuple literal of n zeros. For n == 1 this is
      "(0,)", because "(0)" is just the integer 0 and the generated pack()
      indexes array fields element by element.
    """
    if field.type == "char":
        if field.array_length:
            return 'b""'
        return 'b"\\x00"'
    if field.array_length == 0:
        return "0"
    if field.array_length == 1:
        return "(0,)"
    return "(" + ", ".join(["0"] * field.array_length) + ")"


def generate_mavlink_class(outf, msgs, xml):
    print("Generating MAVLink class")

    t.write(
        outf,
        """


mavlink_map${type_dict_int_to_mavlink_message_type} = {
""",
        xml,
    )
    for m in msgs:
        outf.write(
            "    MAVLINK_MSG_ID_%s: MAVLink_%s_message,\n" % (m.name.upper(), m.name.lower())
        )
    outf.write("}\n\n")

    t.write(
        outf,
        '''

class MAVError(Exception):
    """MAVLink error class"""

    def __init__(self, msg${type_str})${type_none_ret}:
        Exception.__init__(self, msg)
        self.message = msg


class MAVLink_bad_data(MAVLink_message):
    """
    a piece of bad data in a mavlink stream
    """

    def __init__(self, data${type_bytes}, reason${type_str})${type_none_ret}:
        MAVLink_message.__init__(self, MAVLINK_MSG_ID_BAD_DATA, "BAD_DATA")
        self._fieldnames = ["data", "reason"]
        self.data = data
        self.reason = reason
        self._msgbuf = bytearray(data)
        self._instance_field = None

    def __str__(self)${type_str_ret}:
        """Override the __str__ function from MAVLink_messages because non-printable characters are common in to be the reason for this message to exist."""
        if sys.version_info[0] == 2:
            hexstr = ["{:x}".format(ord(i)) for i in self.data]
        else:
            hexstr = ["{:x}".format(i) for i in self.data]
        return "%s {%s, data:%s}" % (self._type, self.reason, hexstr)


class MAVLink_unknown(MAVLink_message):
    """
    a message that we don't have in the XML used when built
    """

    def __init__(self, msgid${type_int}, data${type_bytes})${type_none_ret}:
        MAVLink_message.__init__(self, MAVLINK_MSG_ID_UNKNOWN, "UNKNOWN_%u" % msgid)
        self._fieldnames = ["data"]
        self.data = data
        self._msgbuf = bytearray(data)
        self._instance_field = None

    def __str__(self)${type_str_ret}:
        """Override the __str__ function from MAVLink_messages because non-printable characters are common."""
        if sys.version_info[0] == 2:
            hexstr = ["{:x}".format(ord(i)) for i in self.data]
        else:
            hexstr = ["{:x}".format(i) for i in self.data]
        return "%s {data:%s}" % (self._type, hexstr)


class MAVLinkSigning(object):
    """MAVLink signing state class"""

    def __init__(self)${type_none_ret}:
        self.secret_key${type_optional_bytes} = None
        self.timestamp = 0
        self.link_id = 0
        self.sign_outgoing = False
        self.allow_unsigned_callback${type_optional_mavlink_message_signed_callback} = None
        self.stream_timestamps${type_dict_int_int_int_to_int} = {}
        self.sig_count = 0
        self.badsig_count = 0
        self.goodsig_count = 0
        self.unsigned_count = 0
        self.reject_count = 0


class MAVLink(object):
    """MAVLink protocol handling class"""

    def __init__(self, file${type_any}, srcSystem${type_int_default}, srcComponent${type_int_default}, use_native${type_bool_default})${type_none_ret}:
        self.seq = 0
        self.file = file
        self.srcSystem = srcSystem
        self.srcComponent = srcComponent
        self.callback${type_optional_generic_callback} = None
        self.callback_args${type_optional_args} = None
        self.callback_kwargs${type_optional_kwargs} = None
        self.send_callback${type_optional_generic_callback} = None
        self.send_callback_args${type_optional_args} = None
        self.send_callback_kwargs${type_optional_kwargs} = None
        self.buf = bytearray()
        self.buf_index = 0
        self.expected_length = HEADER_LEN_V1 + 2
        self.have_prefix_error = False
        self.robust_parsing = False
        self.protocol_marker = ${protocol_marker}
        self.little_endian = ${little_endian}
        self.crc_extra = ${crc_extra}
        self.sort_fields = ${sort_fields}
        self.total_packets_sent = 0
        self.total_bytes_sent = 0
        self.total_packets_received = 0
        self.total_bytes_received = 0
        self.total_receive_errors = 0
        self.startup_time = time.time()
        self.signing = MAVLinkSigning()
        self.mav20_unpacker = struct.Struct("<cBBBBBBHB")
        self.mav10_unpacker = struct.Struct("<cBBBBB")
        self.mav20_h3_unpacker = struct.Struct("BBB")
        self.mav_csum_unpacker = struct.Struct("<H")
        self.mav_sign_unpacker = struct.Struct("<IH")

    def set_callback(self, callback${type_generic_callback}, *args${type_any}, **kwargs${type_any})${type_none_ret}:
        self.callback = callback
        self.callback_args = args
        self.callback_kwargs = kwargs

    def set_send_callback(self, callback${type_generic_callback}, *args${type_any}, **kwargs${type_any})${type_none_ret}:
        self.send_callback = callback
        self.send_callback_args = args
        self.send_callback_kwargs = kwargs

    def send(self, mavmsg${type_mavlink_message}, force_mavlink1${type_bool_default})${type_none_ret}:
        """send a MAVLink message"""
        buf = mavmsg.pack(self, force_mavlink1=force_mavlink1)
        self.file.write(buf)
        self.seq = (self.seq + 1) % 256
        self.total_packets_sent += 1
        self.total_bytes_sent += len(buf)
        if self.send_callback is not None and self.send_callback_args is not None and self.send_callback_kwargs is not None:
            self.send_callback(mavmsg, *self.send_callback_args, **self.send_callback_kwargs)

    def buf_len(self)${type_int_ret}:
        return len(self.buf) - self.buf_index

    def bytes_needed(self)${type_int_ret}:
        """return number of bytes needed for next parsing stage"""
        ret = self.expected_length - self.buf_len()

        if ret <= 0:
            return 1
        return ret

    def __callbacks(self, msg${type_mavlink_message})${type_none_ret}:
        """this method exists only to make profiling results easier to read"""
        if self.callback is not None and self.callback_args is not None and self.callback_kwargs is not None:
            self.callback(msg, *self.callback_args, **self.callback_kwargs)

    def parse_char(self, c${type_intseq})${type_optional_mavlink_message_ret}:
        """input some data bytes, possibly returning a new message"""
        self.buf.extend(c)

        self.total_bytes_received += len(c)

        m = self.__parse_char_legacy()

        if m is not None:
            self.total_packets_received += 1
            self.__callbacks(m)
        else:
            # XXX The idea here is if we've read something and there's nothing left in
            # the buffer, reset it to 0 which frees the memory
            if self.buf_len() == 0 and self.buf_index != 0:
                self.buf = bytearray()
                self.buf_index = 0

        return m

    def __parse_char_legacy(self)${type_optional_mavlink_message_ret}:
        """input some data bytes, possibly returning a new message"""
        header_len = HEADER_LEN_V1
        if self.buf_len() >= 1 and self.buf[self.buf_index] == PROTOCOL_MARKER_V2:
            header_len = HEADER_LEN_V2

        m${type_optional_mavlink_message} = None
        if self.buf_len() >= 1 and self.buf[self.buf_index] != PROTOCOL_MARKER_V1 and self.buf[self.buf_index] != PROTOCOL_MARKER_V2:
            magic = self.buf[self.buf_index]
            self.buf_index += 1
            if self.robust_parsing:
                m = MAVLink_bad_data(bytearray([magic]), "Bad prefix")
                self.expected_length = header_len + 2
                self.total_receive_errors += 1
                return m
            if self.have_prefix_error:
                return None
            self.have_prefix_error = True
            self.total_receive_errors += 1
            raise MAVError("invalid MAVLink prefix '%s'" % magic)
        self.have_prefix_error = False
        if self.buf_len() >= 3:
            sbuf = self.buf[self.buf_index : 3 + self.buf_index]
            (magic, self.expected_length, incompat_flags) = cast(
                ${type_tuple_int_int_int_cast},
                self.mav20_h3_unpacker.unpack(sbuf),
            )
            if magic == PROTOCOL_MARKER_V2 and (incompat_flags & MAVLINK_IFLAG_SIGNED):
                self.expected_length += MAVLINK_SIGNATURE_BLOCK_LEN
            self.expected_length += header_len + 2
        if self.expected_length >= (header_len + 2) and self.buf_len() >= self.expected_length:
            mbuf = self.buf[self.buf_index : self.buf_index + self.expected_length]
            self.buf_index += self.expected_length
            self.expected_length = header_len + 2
            if self.robust_parsing:
                try:
                    if magic == PROTOCOL_MARKER_V2 and (incompat_flags & ~MAVLINK_IFLAG_SIGNED) != 0:
                        raise MAVError("invalid incompat_flags 0x%x 0x%x %u" % (incompat_flags, magic, self.expected_length))
                    m = self.decode(mbuf)
                except MAVError as reason:
                    m = MAVLink_bad_data(mbuf, reason.message)
                    self.total_receive_errors += 1
            else:
                if magic == PROTOCOL_MARKER_V2 and (incompat_flags & ~MAVLINK_IFLAG_SIGNED) != 0:
                    raise MAVError("invalid incompat_flags 0x%x 0x%x %u" % (incompat_flags, magic, self.expected_length))
                m = self.decode(mbuf)
            return m
        return None

    def parse_buffer(self, s${type_intseq})${type_optional_mavlink_message_list_ret}:
        """input some data bytes, possibly returning a list of new messages"""
        m = self.parse_char(s)
        if m is None:
            return None
        ret = [m]
        while True:
            m = self.parse_char(b"")
            if m is None:
                return ret
            ret.append(m)

    def check_signature(self, msgbuf${type_bytearray}, srcSystem${type_int}, srcComponent${type_int})${type_bool_ret}:
        """check signature on incoming message"""
        assert self.signing.secret_key is not None

        timestamp_buf = msgbuf[-12:-6]
        link_id = msgbuf[-13]
        (tlow, thigh) = cast(
            ${type_tuple_int_int_cast},
            self.mav_sign_unpacker.unpack(timestamp_buf),
        )
        timestamp = tlow + (thigh << 32)

        # see if the timestamp is acceptable
        stream_key = (link_id, srcSystem, srcComponent)
        if stream_key in self.signing.stream_timestamps:
            if timestamp <= self.signing.stream_timestamps[stream_key]:
                # reject old timestamp
                logger.info("old timestamp")
                return False
        else:
            # a new stream has appeared. Accept the timestamp if it is at most
            # one minute behind our current timestamp
            if timestamp + 6000 * 1000 < self.signing.timestamp:
                logger.info("bad new stream %s %s", timestamp / (100.0 * 1000 * 60 * 60 * 24 * 365), self.signing.timestamp / (100.0 * 1000 * 60 * 60 * 24 * 365))
                return False
            self.signing.stream_timestamps[stream_key] = timestamp
            logger.info("new stream")

        h = hashlib.new("sha256")
        h.update(self.signing.secret_key)
        h.update(msgbuf[:-6])
        sig1 = h.digest()[:6]
        sig2 = msgbuf[-6:]
        if sig1 != sig2:
            logger.info("sig mismatch")
            return False

        # the timestamp we next send with is the max of the received timestamp and
        # our current timestamp
        self.signing.timestamp = max(self.signing.timestamp, timestamp)
        return True

    def decode(self, msgbuf${type_bytearray})${type_mavlink_message_ret}:
        """decode a buffer as a MAVLink message"""
        # decode the header
        if msgbuf[0] != PROTOCOL_MARKER_V1:
            headerlen = 10
            try:
                magic, mlen, incompat_flags, compat_flags, seq, srcSystem, srcComponent, msgIdlow, msgIdhigh = cast(
                    ${type_tuple_bytes_eight_int_cast},
                    self.mav20_unpacker.unpack(msgbuf[:headerlen]),
                )
            except struct.error as emsg:
                raise MAVError("Unable to unpack MAVLink header: %s" % emsg)
            msgId = msgIdlow | (msgIdhigh << 16)
            mapkey = msgId
        else:
            headerlen = 6
            try:
                magic, mlen, seq, srcSystem, srcComponent, msgId = cast(
                    ${type_tuple_bytes_five_int_cast},
                    self.mav10_unpacker.unpack(msgbuf[:headerlen]),
                )
                incompat_flags = 0
                compat_flags = 0
            except struct.error as emsg:
                raise MAVError("Unable to unpack MAVLink header: %s" % emsg)
            mapkey = msgId
        if (incompat_flags & MAVLINK_IFLAG_SIGNED) != 0:
            signature_len = MAVLINK_SIGNATURE_BLOCK_LEN
        else:
            signature_len = 0

        if ord(magic) != PROTOCOL_MARKER_V1 and ord(magic) != PROTOCOL_MARKER_V2:
            raise MAVError("invalid MAVLink prefix '{}'".format(hex(ord(magic))))
        if mlen != len(msgbuf) - (headerlen + 2 + signature_len):
            raise MAVError("invalid MAVLink message length. Got %u expected %u, msgId=%u headerlen=%u" % (len(msgbuf) - (headerlen + 2 + signature_len), mlen, msgId, headerlen))

        if mapkey not in mavlink_map:
            return MAVLink_unknown(msgId, msgbuf)

        # decode the payload
        msgtype = mavlink_map[mapkey]
        order_map = msgtype.orders
        len_map = msgtype.lengths
        crc_extra = msgtype.crc_extra

        # decode the checksum
        try:
            (crc,) = cast(
                ${type_tuple_int_cast},
                self.mav_csum_unpacker.unpack(msgbuf[-(2 + signature_len) :][:2]),
            )
        except struct.error as emsg:
            raise MAVError("Unable to unpack MAVLink CRC: %s" % emsg)
        crcbuf = msgbuf[1 : -(2 + signature_len)]
        if ${crc_extra}:
            # using CRC extra
            crcbuf.append(crc_extra)
        crc2 = x25crc(crcbuf)
        if crc != crc2.crc and not MAVLINK_IGNORE_CRC:
            raise MAVError("invalid MAVLink CRC in msgID %u 0x%04x should be 0x%04x" % (msgId, crc, crc2.crc))

        sig_ok = False
        if signature_len == MAVLINK_SIGNATURE_BLOCK_LEN:
            self.signing.sig_count += 1
        if self.signing.secret_key is not None:
            accept_signature = False
            if signature_len == MAVLINK_SIGNATURE_BLOCK_LEN:
                sig_ok = self.check_signature(msgbuf, srcSystem, srcComponent)
                accept_signature = sig_ok
                if sig_ok:
                    self.signing.goodsig_count += 1
                else:
                    self.signing.badsig_count += 1
                if not accept_signature and self.signing.allow_unsigned_callback is not None:
                    accept_signature = self.signing.allow_unsigned_callback(self, msgId)
                    if accept_signature:
                        self.signing.unsigned_count += 1
                    else:
                        self.signing.reject_count += 1
            elif self.signing.allow_unsigned_callback is not None:
                accept_signature = self.signing.allow_unsigned_callback(self, msgId)
                if accept_signature:
                    self.signing.unsigned_count += 1
                else:
                    self.signing.reject_count += 1
            if not accept_signature:
                raise MAVError("Invalid signature")

        csize = msgtype.unpacker.size
        mbuf = msgbuf[headerlen : -(2 + signature_len)]
        if len(mbuf) < csize:
            # zero pad to give right size
            mbuf.extend([0] * (csize - len(mbuf)))
        if len(mbuf) < csize:
            raise MAVError("Bad message of type %s length %u needs %s" % (msgtype, len(mbuf), csize))
        mbuf = mbuf[:csize]
        try:
            t = cast(
                ${type_tuple_bytes_int_float_repeat_cast},
                msgtype.unpacker.unpack(mbuf),
            )
        except struct.error as emsg:
            raise MAVError("Unable to unpack MAVLink payload type=%s payloadLength=%u: %s" % (msgtype, len(mbuf), emsg))

        tlist${type_mavlink_message_assign_attr_list} = list(t)
        # handle sorted fields
        if ${sort_fields}:
            if sum(len_map) == len(len_map):
                # message has no arrays in it
                for i in range(0, len(tlist)):
                    tlist[i] = t[order_map[i]]
            else:
                # message has some arrays
                tlist = []
                for i in range(0, len(order_map)):
                    order = order_map[i]
                    L = len_map[order]
                    tip = sum(len_map[:order])
                    field = t[tip]
                    if L == 1 or isinstance(field, bytes):
                        tlist.append(field)
                    else:
                        tlist.append(cast(${type_intseq_floatseq_cast}, list(t[tip : (tip + L)])))

        # terminate any strings
        for i, elem in enumerate(tlist):
            if isinstance(elem, bytes):
                tlist[i] = elem.rstrip(b"\\x00")

        # construct the message object
        try:
            # Note that initializers don't follow the Liskov Substitution Principle
            # therefore it can't be typechecked
            m = msgtype(*tlist)  # type: ignore
        except Exception as emsg:
            raise MAVError("Unable to instantiate MAVLink message of type %s : %s" % (msgtype, emsg))
        m._signed = sig_ok
        if m._signed:
            m._link_id = msgbuf[-13]
        m._msgbuf = msgbuf
        m._payload = msgbuf[6 : -(2 + signature_len)]
        m._crc = crc
        m._header = MAVLink_header(msgId, incompat_flags, compat_flags, mlen, seq, srcSystem, srcComponent)
        return m
''',
        xml,
    )


def generate_methods(outf, msgs, enable_type_annotations):
    """Write the ``<name>_encode`` and ``<name>_send`` methods for every message.

    For each message ``m`` in ``msgs`` the template emits two methods:

    * ``<name>_encode(self, <args>)`` returning ``MAVLink_<name>_message(<fields>)``;
    * ``<name>_send(self, <args>, force_mavlink1=...)`` which passes the result
      of ``<name>_encode`` to ``self.send``.

    Both share a docstring made of the wrapped message description followed by
    one line per field (units, wire type and enum name when present).

    Args:
        outf: writable text stream receiving the generated source.
        msgs: parsed message definitions. ``fielddefaults`` must already be
            populated, because extension-field defaults are read from it.
        enable_type_annotations: when true, arguments are annotated with the
            type returned by ``mavpytype`` and ``<name>_encode`` gets a return
            annotation.
    """
    print("Generating methods")

    def field_descriptions(fields):
        """Return the docstring lines describing ``fields``, one line per field."""
        lines = []
        for f in fields:
            field_info = ""
            if f.units:
                field_info += "%s " % f.units
            field_info += "(type:%s" % f.type
            if f.enum:
                field_info += ", values:%s" % f.enum
            field_info += ")"
            lines.append(
                "        %-18s        : %s %s\n"
                % (
                    f.name,
                    f.description.strip(),
                    field_info,
                )
            )
        return "".join(lines)

    def argument_spec(m, i, f):
        """Return the source text declaring field ``f`` (index ``i`` of ``m.fields``) as an argument."""
        # Fields flagged omit_arg default to their const_value, and fields at or
        # after extensions_start default to m.fielddefaults[i]. Every other
        # field is a required positional argument.
        has_default = True
        if f.omit_arg:
            default = f.const_value
        elif m.extensions_start is not None and i >= m.extensions_start:
            default = m.fielddefaults[i]
        else:
            has_default = False
            default = None

        if enable_type_annotations:
            # annotated form: "name: type" / "name: type = default"
            spec = "%s: %s" % (f.name, mavpytype(f))
            if has_default:
                spec += " = %s" % (default,)
        else:
            # plain form: "name" / "name=default"
            spec = "%s" % (f.name,)
            if has_default:
                spec += "=%s" % (default,)
        return spec

    wrapper = textwrap.TextWrapper(initial_indent="", subsequent_indent="        ")

    for m in msgs:
        comment = "%s\n\n%s" % (wrapper.fill(m.description.strip()), field_descriptions(m.fields))
        field_names = [argument_spec(m, i, f) for i, f in enumerate(m.fields)]

        self_ret_type = ""
        if enable_type_annotations:
            self_ret_type = " -> MAVLink_" + m.name.lower() + "_message"

        t.write(
            outf,
            '''

    def ${NAMELOWER}_encode(self, ${ARG_FIELDNAMES})${self_ret_type}:
        """
        ${COMMENT}
        """
        return MAVLink_${NAMELOWER}_message(${FIELDNAMES})

    def ${NAMELOWER}_send(self, ${ARG_FIELDNAMES}, force_mavlink1${type_bool_default})${type_none_ret}:
        """
        ${COMMENT}
        """
        self.send(self.${NAMELOWER}_encode(${FIELDNAMES}), force_mavlink1=force_mavlink1)
''',
            extend_with_type_info(
                {
                    "NAMELOWER": m.name.lower(),
                    "ARG_FIELDNAMES": ", ".join(field_names),
                    "COMMENT": comment,
                    "FIELDNAMES": ", ".join(m.fieldnames),
                    "self_ret_type": self_ret_type,
                },
                enable_type_annotations,
            ),
        )


def _prepare_message_layout(m, byte_order):
    """Attach the struct formats and field maps the templates need to message ``m``.

    Sets ``fmtstr`` and ``native_fmtstr``: ``byte_order`` followed by one code
    per field of ``m.ordered_fields``, from ``mavfmt`` and ``native_mavfmt``
    respectively. Also sets ``fielddefaults``, ``instance_field``,
    ``order_map``, ``len_map`` and ``array_len_map``.
    """
    m.fielddefaults = []
    m.fmtstr = byte_order
    m.native_fmtstr = byte_order
    m.instance_field = None
    for f in m.ordered_fields:
        m.fmtstr += mavfmt(f)
        m.fielddefaults.append(mavdefault(f))
        m.native_fmtstr += native_mavfmt(f)
        if f.instance:
            # if several fields are flagged, the last one in ordered_fields wins
            m.instance_field = f.name

    nfields = len(m.fieldnames)
    # order_map[i] = position of m.fieldnames[i] within m.ordered_fieldnames
    m.order_map = [m.ordered_fieldnames.index(name) for name in m.fieldnames]
    m.array_len_map = [m.ordered_fields[i].array_length for i in range(nfields)]
    # len_map is indexed by ordered position: len_map[order_map[i]] = fieldlengths[i]
    m.len_map = [0] * nfields
    for i, pos in enumerate(m.order_map):
        m.len_map[pos] = m.fieldlengths[i]


def generate(basename, xml, enable_type_annotations=False):
    """Generate the complete Python implementation for the parsed XML definitions.

    Args:
        basename: output path. ``.py`` is appended unless it is already present.
        xml: list of parsed XML documents. Their messages and enums are merged.
            Byte order and the template variables come from ``xml[0]``.
        enable_type_annotations: emit type hints in the generated module.
    """
    if basename.endswith(".py"):
        filename = basename
    else:
        filename = basename + ".py"

    msgs = []
    enums = []
    filelist = []
    for x in xml:
        msgs.extend(x.message)
        enums.extend(x.enum)
        filelist.append(os.path.basename(x.filename))

    for m in msgs:
        # every message struct uses the byte order declared by the first document
        _prepare_message_layout(m, "<" if xml[0].little_endian else ">")

    print("Generating %s" % filename)
    # A context manager guarantees the output file is flushed and closed even
    # when one of the generation stages below raises part-way through.
    with open(filename, "w") as outf:
        # template variables: the first document's attributes, passed through
        # extend_with_type_info
        template_vars = extend_with_type_info(xml[0].__dict__, enable_type_annotations)
        generate_preamble(outf, msgs, basename, filelist, template_vars)
        generate_enums(outf, enums, enable_type_annotations)
        generate_message_ids(outf, msgs)
        generate_classes(outf, msgs, enable_type_annotations)
        generate_mavlink_class(outf, msgs, template_vars)
        generate_methods(outf, msgs, enable_type_annotations)
    print("Generated %s OK" % filename)
